import torch
from Vae.record_vae_and_masks_state import record_encodings_and_mask_state
from data_utils.vae_dataset import ImageDataset
import wandb
from models.cnn_vae import CNNVAE
from torch.utils.data import DataLoader
import torch.optim as optim
import glob
import random
import numpy as np
from tqdm import tqdm
import os
import time
from PIL import Image
import torchvision
from data_utils.data_util import _CustomDataParallel
import torch.nn as nn


def train(model, optimizer, epochs, val_freq, device, N_batches_per_epoch=1000,
          train_loader=None, val_loader=None, log_dir=None):
    """Train the VAE with periodic validation, checkpointing and W&B logging.

    Every epoch starts a fresh iterator over ``train_loader`` and performs at
    most ``N_batches_per_epoch`` optimisation steps on
    ``recon_loss + 0.1 * kl_loss``; the epoch ends early if the loader runs
    out first. When ``epoch > 0`` and ``epoch % val_freq == 0`` the model is
    checkpointed, validated (``best_model.pth`` is refreshed on improvement)
    and a decoded latent-sample image is attached to the log. One
    ``wandb.log`` call is made per epoch.

    Args:
        model: VAE whose ``forward(x)`` returns ``(x_hat, mean, log_var)`` and
            which provides ``get_vae_loss(x, x_hat, mean, log_var)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of epochs to run.
        val_freq: Validate/checkpoint every ``val_freq`` epochs (epoch 0 is
            never validated).
        device: Device each batch is moved to.
        N_batches_per_epoch: Upper bound on optimisation steps per epoch.
        train_loader: Loader of input batches. Defaults to the module-level
            ``train_loader`` created in the ``__main__`` block.
        val_loader: Validation loader. Defaults to the module-level
            ``val_loader``.
        log_dir: Directory that checkpoints are written to. Defaults to the
            module-level ``log_dir``.
    """
    # Backward compatibility: the ``__main__`` block defines these as module
    # globals and calls ``train`` without passing them.
    module_globals = globals()
    if train_loader is None:
        train_loader = module_globals.get("train_loader")
    if val_loader is None:
        val_loader = module_globals.get("val_loader")
    if log_dir is None:
        log_dir = module_globals.get("log_dir")

    model.train()
    best_val_loss = float('inf')
    for epoch in tqdm(range(epochs)):
        overall_loss = []
        all_recon_loss = []
        all_kl_loss = []
        N = len(train_loader)
        print("Total number of batches: ", N, "Number of batches per epoch: ", N_batches_per_epoch)
        # A fresh iterator per epoch; the epoch may end before N_batches_per_epoch.
        batch_iterator = iter(train_loader)
        for batch_idx in range(N_batches_per_epoch):
            try:
                x = next(batch_iterator)
            except StopIteration:
                # This block is in case N_batches_per_epoch exceeds the available batches
                print(f"Reached the end of the dataset at batch index {batch_idx}, less than the specified N_batches_per_epoch")
                break
            x = x.to(device)
            optimizer.zero_grad()

            x_hat, mean, log_var = model(x)
            recon_loss, kl_loss = model.get_vae_loss(x, x_hat, mean, log_var)
            loss = recon_loss + 0.1 * kl_loss

            # Read each scalar once: every .item() copies the value to the host.
            loss_value = loss.item()
            recon_value = recon_loss.item()
            kl_value = kl_loss.item()
            all_recon_loss.append(recon_value)
            all_kl_loss.append(kl_value)
            overall_loss.append(loss_value)

            loss.backward()
            optimizer.step()
            print("\t\tBatch", batch_idx, "\tLoss: ", loss_value, "\tRecon Loss: ", recon_value, "\tKL Loss: ", kl_value)
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", np.mean(overall_loss))
        log_dict = {"Epoch": epoch,
                    "Training Loss": np.mean(overall_loss),
                    "Training Reconstruction Loss": np.mean(all_recon_loss),
                    "Training KL Loss": np.mean(all_kl_loss)}

        # Periodic checkpoint + validation (merged into this epoch's W&B log).
        if epoch > 0 and epoch % val_freq == 0:
            torch.save(model.state_dict(), os.path.join(log_dir, f'model_checkpoint_{epoch}.pth'))

            with torch.no_grad():
                val_log_dict = validate(model, val_loader, device)
                log_dict.update(val_log_dict)
                val_loss = val_log_dict["Validation Loss"]
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), os.path.join(log_dir, 'best_model.pth'))
                    print("\tBest model saved at epoch : ", epoch)

                plots = plot_latent_space(model)
                log_dict.update({"Latent Space Images": wandb.Image(plots)})
        # Log metrics in W&B
        wandb.log(log_dict)

def train_fit_linear(model, optimizer, epochs, val_freq, device, N_batches_per_epoch=1000,
                     train_loader=None, val_loader=None, log_dir=None):
    """Train ``model.forward_fit_linear`` to regress targets with an MSE loss.

    Unlike ``train``, every epoch runs exactly ``N_batches_per_epoch`` steps:
    when ``train_loader`` is exhausted a new iterator over it is started.
    When ``epoch > 0`` and ``epoch % val_freq == 0`` the model is
    checkpointed and validated with ``validate_fit_linear``, refreshing
    ``best_model.pth`` on improvement. One ``wandb.log`` call per epoch.

    Args:
        model: Module providing ``forward_fit_linear(x)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of epochs to run.
        val_freq: Validate/checkpoint every ``val_freq`` epochs (epoch 0 is
            never validated).
        device: Device each batch is moved to.
        N_batches_per_epoch: Optimisation steps per epoch.
        train_loader: Loader yielding ``(x, target)`` pairs. Defaults to the
            module-level ``train_loader`` created in the ``__main__`` block.
        val_loader: Validation loader. Defaults to the module-level
            ``val_loader``.
        log_dir: Directory that checkpoints are written to. Defaults to the
            module-level ``log_dir``.

    Raises:
        ValueError: if ``train_loader`` yields no batches at all.
    """
    # Backward compatibility: the ``__main__`` block defines these as module
    # globals and previously relied on them implicitly.
    module_globals = globals()
    if train_loader is None:
        train_loader = module_globals.get("train_loader")
    if val_loader is None:
        val_loader = module_globals.get("val_loader")
    if log_dir is None:
        log_dir = module_globals.get("log_dir")

    model.train()
    best_val_loss = float('inf')
    for epoch in tqdm(range(epochs)):
        overall_loss = []

        batch_iterator = iter(train_loader)
        for batch_idx in range(N_batches_per_epoch):
            batch = next(batch_iterator, None)
            if batch is None:
                # Loader exhausted: restart it so the epoch still gets
                # N_batches_per_epoch steps.
                batch_iterator = iter(train_loader)
                batch = next(batch_iterator, None)
                if batch is None:
                    raise ValueError("train_fit_linear: train_loader produced no batches")
            x, target = batch
            x, target = x.to(device), target.to(device)
            optimizer.zero_grad()

            pred = model.forward_fit_linear(x)
            loss = nn.functional.mse_loss(pred, target)
            loss_value = loss.item()
            overall_loss.append(loss_value)

            loss.backward()
            optimizer.step()
            print("\t\tBatch", batch_idx, "\tLoss: ", loss_value)
        print("\tEpoch", epoch + 1, "\tAverage Loss: ", np.mean(overall_loss))
        log_dict = {"Epoch": epoch,
                    "Training Loss": np.mean(overall_loss)}

        # Periodic checkpoint + validation (merged into this epoch's W&B log).
        if epoch > 0 and epoch % val_freq == 0:
            torch.save(model.state_dict(), os.path.join(log_dir, f'model_checkpoint_{epoch}.pth'))

            with torch.no_grad():
                val_log_dict = validate_fit_linear(model, val_loader, device)
                log_dict.update(val_log_dict)
                val_loss = val_log_dict["Validation Loss"]
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), os.path.join(log_dir, 'best_model.pth'))
                    print("\tBest model saved at epoch : ", epoch)
        # Log metrics in W&B
        wandb.log(log_dict)

def plot_latent_space(model, scale=1.0, n=10, digit_size=64, latent_dim=None, device=None):
    """Decode random latent samples and tile them into one n x n RGB image.

    The name follows the classic "2-D latent manifold" plot, but no grid is
    walked: ``n * n`` latent vectors are drawn from a normal distribution
    with standard deviation ``scale`` and decoded in one batch. The last three
    channels of each reconstruction, scaled by 255, fill one
    ``digit_size x digit_size`` tile in row-major order.

    Args:
        model: Module exposing ``decode(z)``. Its train/eval mode on entry is
            restored before returning.
        scale: Standard-deviation multiplier for the latent samples (1.0
            samples a standard normal, the previous behaviour).
        n: Number of tiles per row and per column.
        digit_size: Tile height/width in pixels; must equal the decoder's
            output height/width.
        latent_dim: Length of each latent vector. Defaults to the module-level
            ``latent_dim`` set in the ``__main__`` block.
        device: Device the samples are moved to. Defaults to the device of
            ``model``'s first parameter (CPU if it has none).

    Returns:
        A ``PIL.Image.Image`` of size ``n * digit_size`` by ``n * digit_size``.

    Raises:
        ValueError: if ``latent_dim`` is neither passed nor defined globally.
    """
    if latent_dim is None:
        latent_dim = globals().get("latent_dim")
        if latent_dim is None:
            raise ValueError(
                "plot_latent_space: pass latent_dim explicitly "
                "(no module-level latent_dim is defined)"
            )
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Drawn from the CPU generator and then moved, as before.
            z = (scale * torch.randn(n * n, latent_dim)).to(device)
            decoded = model.decode(z)
            # (batch, C, H, W) -> (batch, H, W, 3) using the last 3 channels.
            tiles = decoded[:, -3:, :, :].cpu().permute(0, 2, 3, 1).numpy() * 255.0
    finally:
        model.train(was_training)

    figure = np.zeros((digit_size * n, digit_size * n, 3))
    for k, tile in enumerate(tiles):
        row, col = divmod(k, n)
        figure[row * digit_size:(row + 1) * digit_size,
               col * digit_size:(col + 1) * digit_size, :] = tile
    # Clip before the uint8 cast: out-of-range floats are not saturated by it.
    return Image.fromarray(np.clip(figure, 0, 255).astype(np.uint8))


def validate(model, val_loader, device):
    """Evaluate the VAE on ``val_loader`` and build a W&B log dict.

    Losses are computed per batch as ``recon_loss + 0.1 * kl_loss`` (matching
    ``train``) and averaged over batches. Image grids are built from up to
    the first 25 samples of the LAST batch, showing the last three channels.

    Args:
        model: VAE whose ``forward(x)`` returns ``(x_hat, mean, log_var)`` and
            which provides ``get_vae_loss``. Its train/eval mode on entry is
            restored before returning (also on error).
        val_loader: Iterable of input batches.
        device: Device each batch is moved to.

    Returns:
        Dict with "Validation Loss", "Validation Reconstruction Loss",
        "Validation KL Loss", "Original Images" and "Reconstructed Images".

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    was_training = model.training
    model.eval()
    overall_loss = []
    all_recon_loss = []
    all_kl_loss = []
    x = x_hat = None
    try:
        with torch.no_grad():
            for batch in val_loader:
                x = batch.to(device)
                x_hat, mean, log_var = model(x)
                recon_loss, kl_loss = model.get_vae_loss(x, x_hat, mean, log_var)
                loss = recon_loss + 0.1 * kl_loss

                all_recon_loss.append(recon_loss.item())
                all_kl_loss.append(kl_loss.item())
                overall_loss.append(loss.item())

            if x_hat is None:
                raise ValueError("validate: val_loader produced no batches")

            # Create a grid of 5x5 reconstructions from the last batch.
            to_pil = torchvision.transforms.ToPILImage()
            test_grid = torchvision.utils.make_grid(x[:25, -3:, :, :], nrow=5)
            test_image = to_pil(test_grid)
            # Clamp so ToPILImage's float->byte conversion cannot wrap around
            # on decoder outputs slightly outside [0, 1].
            grid = torchvision.utils.make_grid(x_hat[:25, -3:, :, :].clamp(0, 1), nrow=5)
            image = to_pil(grid)

        mean_loss = np.mean(overall_loss)
        print("\tValidation Loss: ", mean_loss)
        log_dict = {"Validation Loss": mean_loss,
                    "Validation Reconstruction Loss": np.mean(all_recon_loss),
                    "Validation KL Loss": np.mean(all_kl_loss),
                    "Original Images": wandb.Image(test_image),
                    "Reconstructed Images": wandb.Image(image)}
    finally:
        model.train(was_training)
    return log_dict

def validate_fit_linear(model, val_loader, device):
    """Mean per-batch MSE of ``model.forward_fit_linear`` on ``val_loader``.

    Args:
        model: Module providing ``forward_fit_linear(x)``. It is run in eval
            mode, and its train/eval mode on entry is restored before
            returning (also on error).
        val_loader: Iterable of ``(x, target)`` pairs.
        device: Device each pair is moved to.

    Returns:
        ``{"Validation Loss": mean of the per-batch MSE losses}``.

    Raises:
        ValueError: if ``val_loader`` yields no batches (the mean would
            otherwise silently be NaN).
    """
    was_training = model.training
    model.eval()
    overall_loss = []
    try:
        with torch.no_grad():
            for x, target in val_loader:
                x, target = x.to(device), target.to(device)
                pred = model.forward_fit_linear(x)
                overall_loss.append(nn.functional.mse_loss(pred, target).item())

        if not overall_loss:
            raise ValueError("validate_fit_linear: val_loader produced no batches")
        mean_loss = np.mean(overall_loss)
        print("\tValidation Loss: ", mean_loss)
    finally:
        model.train(was_training)
    return {"Validation Loss": mean_loss}

if __name__ == "__main__":
    import sys

    # Add the parent folder to sys.path so the late project import of
    # ``Record.file_management`` below resolves when run as a script.
    # NOTE(review): the project imports at the top of the file run before
    # this line, so they still depend on the launch directory / PYTHONPATH.
    parent_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    sys.path.append(parent_folder)

    # Hyperparameters. ``latent_dim``, ``device``, ``train_loader``,
    # ``val_loader`` and ``log_dir`` are module globals that the training
    # helpers above fall back to when not passed explicitly.
    batch_size = 1024
    latent_dim = 10
    frame_stack = 5
    lr = 1e-2
    epochs = 20
    N_batches_per_epoch = 100
    val_freq = 1
    model_path = ''
    # Prefer CUDA but fall back to CPU so the script can still start elsewhere.
    # NOTE(review): confirm _CustomDataParallel behaves on a CPU-only machine.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_flow = False
    fit_linear = False

    # Dataset discovery first, so a missing dataset fails before a run
    # directory or W&B run is created. glob's order depends on the
    # filesystem; sort it so anything in ImageDataset that depends on file
    # order (presumably the train/val split) is reproducible.
    base_dir = '/data/carlq/research/ac_infer/datasets/box2d_default'
    file_paths = sorted(glob.glob(base_dir + '/*/state*_*.png'))
    if not file_paths:
        raise FileNotFoundError(f"No 'state*_*.png' frames found under {base_dir}")

    # Create a new directory for saving logs and checkpoints
    current_time = time.strftime("%Y%m%d-%H%M%S")
    run_name = current_time + '_' + f"default_cnn_vae_frame5_z{latent_dim}_full_data"
    log_dir = os.path.join('data', run_name)
    os.makedirs(log_dir, exist_ok=True)
    # Initialize W&B project
    # wandb.init(project="causal_vae", entity="carltheq", group='full dataset run')
    wandb.init(project="causal_vae")  ## TODO: Add entity and group
    wandb.run.name = run_name

    from Record.file_management import read_obj_dumps
    obj_data = read_obj_dumps(base_dir, i=0, rng=-1, filename='object_dumps.txt')

    # Setup random seed
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # Load Data
    train_loader = DataLoader(ImageDataset(file_paths, obj_data, frame_stack=frame_stack, split='train', use_flow=use_flow, fit_linear=fit_linear), batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(ImageDataset(file_paths, obj_data, frame_stack=frame_stack, split='val', use_flow=use_flow, fit_linear=fit_linear), batch_size=batch_size, shuffle=False, num_workers=4)

    # Initialize Model: 5 input channels with flow, else 3 per stacked frame.
    model = CNNVAE(latent_dim=latent_dim, nc=3 * frame_stack if not use_flow else 5, fit_linear=fit_linear).to(device)
    model = _CustomDataParallel(model)
    if model_path:
        # map_location lets a checkpoint saved on another device load here.
        model.load_state_dict(torch.load(model_path, map_location=device))
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Train Model. train_fit_linear consumes (x, target) batches, presumably
    # what ImageDataset yields when fit_linear=True; train consumes bare x.
    if fit_linear:
        train_fit_linear(model, optimizer, epochs=epochs, val_freq=val_freq, device=device, N_batches_per_epoch=N_batches_per_epoch)
    else:
        train(model, optimizer, epochs=epochs, val_freq=val_freq, device=device, N_batches_per_epoch=N_batches_per_epoch)

    ### Encoding: record encodings for the full dataset back into base_dir.
    dataset = ImageDataset(file_paths, obj_data, frame_stack=frame_stack, split='full', ret_frame_info=True, use_flow=use_flow, fit_linear=fit_linear)
    record_encodings_and_mask_state(model, obj_data, dataset, frame_stack, save_rollouts=base_dir)
